Experiments on Snake Activation#
To induce a periodic extrapolation bias in neural networks, Ziyin et al. (2020) proposed a simple activation function called the "Snake" activation, of the form \(x + \frac{1}{a}\sin^2(ax)\), where \(a\) can be treated either as a constant hyperparameter or as a learned parameter.
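The activation itself is a one-liner; a minimal NumPy sketch (the experiments below use the PyTorch equivalent):

```python
import numpy as np

def snake(x, a=1.0):
    """Snake activation: x + sin^2(a*x) / a.

    The identity term keeps the function monotonic on average, while the
    sin^2 ripple (period pi/a) is what gives the network its periodic
    extrapolation bias.
    """
    return x + np.sin(a * x) ** 2 / a
```

Note that the residue `snake(x, a) - x` is periodic with period \(\pi/a\), which is the structure the network can exploit when extrapolating.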
We’ve experimented on the Snake activation…
Extrapolation Experiment#
We generated synthetic data from the sin(x) function, which we aim to learn. The blue points are the training set, which we use to train our model. The orange points are the test set, which we use to check whether the model generalized and actually learned the sine function.
Our inputs are the x-values (horizontal axis) and our targets are y = sin(x), which we will train our model to predict given x.
We show an animation below of the neural network parameters and the evolution of how it fits the data over its training (epochs).
In these experiments, we used:

- Xavier Normal initialization
- two hidden layers
- 512 neurons per hidden layer
- learning rate = 0.0001
- a = 30 for the activation functions
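The setup above can be sketched as a small PyTorch module. This is only an illustration of the stated architecture; the post's own `MLP` class (used in the training code below) may differ in details:

```python
import torch
import torch.nn as nn

class Snake(nn.Module):
    """Snake activation x + sin^2(a*x)/a with a fixed coefficient a."""
    def __init__(self, a=30.0):
        super().__init__()
        self.a = a

    def forward(self, x):
        return x + torch.sin(self.a * x) ** 2 / self.a

class SnakeMLP(nn.Module):
    """Sketch of the experimental setup: a 1 -> 512 -> 512 -> 1 MLP with
    Snake activations and Xavier-normal-initialized weights."""
    def __init__(self, hidden=512, a=30.0):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(1, hidden), Snake(a),
            nn.Linear(hidden, hidden), Snake(a),
            nn.Linear(hidden, 1),
        )
        for m in self.net:
            if isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight)
                nn.init.zeros_(m.bias)

    def forward(self, x):
        return self.net(x)
```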
Interpolation Experiment#
Now, what if we make this an interpolation problem instead of an extrapolation problem? In other words, what if we swap the train and test sets? Will the model be able to infer the sine wave in between the training data? We show below the sine wave colored by train (blue) and test (orange).
Note that we use the term "interpolate" loosely here: this is technically still an extrapolation problem, since the test distribution is not contained in the support of the training distribution.
Extrapolation investigation#
import numpy as np
import torch

def gen_data2(L=1000, prop_train=0.5, start=-30*torch.pi, end=30*torch.pi, reverse=False):
    x = torch.linspace(start, end, L)
    y = torch.sin(x)  # + torch.sin(x/3) + torch.sin(x*3)
    cnt_train = int(L * prop_train)
    # two training segments, each preceded by a gap of L//8 points
    train_inds = [(L//8, L//8 + cnt_train//2),
                  (L//8 + cnt_train//2 + L//8, L//8 + cnt_train//2 + L//8 + cnt_train//2)]
    train_inds = (list(np.arange(train_inds[0][0], train_inds[0][1], 1))
                  + list(np.arange(train_inds[1][0], train_inds[1][1], 1)))
    test_inds = [i for i in np.arange(L) if i not in train_inds]
    x_train, y_train = x[train_inds], y[train_inds]
    x_test, y_test = x[test_inds], y[test_inds]
    if reverse:  # swap train and test
        return x_test, y_test, x_train, y_train
    return x_train, y_train, x_test, y_test
import matplotlib.pyplot as plt

x_train, y_train, x_test, y_test = gen_data2(L=2000, prop_train=0.5, start=-50, end=50, reverse=False)

fig, axs = plt.subplots(1, 1, figsize=(12, 2))
plt.scatter(x_train, y_train, label='train', s=2)
plt.scatter(x_test, y_test, label='test', s=2)
plt.legend()
plt.tight_layout()
axs.spines[['right', 'top']].set_visible(False)
from matplotlib import animation
from IPython.display import HTML

# for animation
fig, axs = plt.subplots(1, 1, figsize=(10, 2))  # dpi=200)
snake_a = 30
model = MLP(count_nodes, snake_a=snake_a)  # MLP, count_nodes, lr, etc. defined earlier
model = model.to(device)
optim = torch.optim.Adam(model.parameters(), lr=lr)
# scheduler = torch.optim.lr_scheduler.StepLR(optim, step_size=1000)
anim = animation.FuncAnimation(fig,
                               animate,
                               frames=frames,
                               interval=1)
HTML(anim.to_jshtml(fps=10))
Fitting on a Decaying Signal#
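The exact decaying target is not reproduced here; one natural choice (an assumption for illustration, not necessarily the signal used in this section) is a sine wave under an exponential envelope:

```python
import numpy as np

def decaying_signal(x, decay=0.05):
    """Hypothetical decaying target: a sine wave whose amplitude
    shrinks exponentially away from the origin. The envelope and
    decay rate here are illustrative choices only."""
    return np.exp(-decay * np.abs(x)) * np.sin(x)
```

A signal like this breaks the exact periodicity that the Snake activation's \(\sin^2\) ripple is biased toward, which is what makes it an interesting fitting target.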
Different coefficients \(\frac{1}{b}\sin(ax)\)#
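Decoupling the two coefficients gives a generalized form \(x + \frac{1}{b}\sin^2(ax)\), where \(a\) controls the ripple frequency and \(1/b\) its amplitude. A minimal sketch (the exact parameterization explored in this section is an assumption):

```python
import numpy as np

def snake_ab(x, a=1.0, b=1.0):
    """Generalized Snake: x + sin^2(a*x) / b.

    `a` sets the ripple frequency and `1/b` its amplitude;
    the standard Snake activation is recovered when b == a.
    """
    return x + np.sin(a * x) ** 2 / b
```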
t-SNE on learned features#
Take-aways#
We performed…